/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com> 
 *
 */
#ifndef SRC_PROTOCOLS_TCP_TCPPROTOCOL_H_
#define SRC_PROTOCOLS_TCP_TCPPROTOCOL_H_

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include "Protocol.h"
#include "flow/FlowManager.h"
#include "flow/FlowCache.h"
#include "FlowForwarder.h"
#include "Cache.h"
#include "TCPStates.h"
#include "TCPInfo.h"

namespace aiengine {

class TCPProtocol: public Protocol {
public:
	explicit TCPProtocol(const std::string &name):
		Protocol(name) {}
	explicit TCPProtocol():TCPProtocol(TCPProtocol::default_name) {}
    	virtual ~TCPProtocol(); 

#if defined(HAVE_TCP_QOS_METRICS)
        class TCPQOSStatistics {
        public:
                explicit TCPQOSStatistics() {
                        std::time(&packet_time);
                        std::time(&last_packet_time);
                }
                virtual ~TCPQOSStatistics() {}

		mutable uint64_t prev_total_connection_setup_time_ = 0;
		mutable uint64_t prev_total_server_reset_rate_ = 0;
		mutable uint64_t prev_total_application_response_time_ = 0;
		mutable std::time_t packet_time;
		mutable std::time_t last_packet_time;
        };
#endif
	static constexpr const char *const default_name = "TCP";
	static const uint16_t header_size = 20;

	uint16_t getId() const override { return IPPROTO_TCP; }
	uint16_t getHeaderSize() const override { return header_size; }

	// Condition for say that a packet is tcp 
	bool check(const Packet &packet) override; 
	void processFlow(Flow *flow) override {}; // This protocol generates flows but not for destination.
	bool processPacket(Packet &packet) override;

	void statistics(std::basic_ostream<char> &out, int level, int32_t limit) const override;
	void statistics(Json &out, int level) const override;

        void releaseCache() override {} // No need to free cache

        void setHeader(const uint8_t *raw_packet) override {
        
                header_ = reinterpret_cast <const struct tcphdr*> (raw_packet);
        }

#if defined(IS__FREEBSD) || defined(IS_OPENBSD) || defined(IS_DARWIN)
    	uint16_t getSourcePort() const { return ntohs(header_->th_sport); }
    	uint16_t getDestinationPort() const { return ntohs(header_->th_dport); }
    	uint32_t getSequence() const  { return ntohl(header_->th_seq); }
    	uint32_t getAckSequence() const  { return ntohl(header_->th_ack); }
    	bool isSyn() const { return (header_->th_flags & TH_SYN) == TH_SYN; }
    	bool isFin() const { return (header_->th_flags & TH_FIN) == TH_FIN; }
    	bool isAck() const { return (header_->th_flags & TH_ACK) == TH_ACK; }
    	bool isRst() const { return (header_->th_flags & TH_RST) == TH_RST; }
    	bool isPushSet() const { return (header_->th_flags & TH_PUSH) == TH_PUSH; }
    	uint16_t getTcpHdrLength() const { return header_->th_off * 4; }
	uint16_t getWindowSize() const { return ntohs(header_->th_win); }
#else
    	bool isSyn() const { return header_->syn == 1; }
    	bool isFin() const { return header_->fin == 1; }
    	bool isAck() const { return header_->ack == 1; }
    	bool isRst() const { return header_->rst == 1; }
    	bool isPushSet() const { return header_->psh == 1; }
    	uint32_t getSequence() const  { return ntohl(header_->seq); }
    	uint32_t getAckSequence() const  { return ntohl(header_->ack_seq); }
    	uint16_t getSourcePort() const { return ntohs(header_->source); }
    	uint16_t getDestinationPort() const { return ntohs(header_->dest); }
    	uint16_t getTcpHdrLength() const { return header_->doff * 4; }
	uint16_t getWindowSize() const { return ntohs(header_->window); }
#endif
    	uint8_t *getPayload() const { return (uint8_t*)header_ + getTcpHdrLength(); }

        void setFlowManager(FlowManagerPtr flow_mng) { flow_table_ = flow_mng; flow_table_->setTCPInfoCache(tcp_info_cache_); }
        FlowManagerPtr getFlowManager() const { return flow_table_; }

        void setFlowCache(FlowCachePtr flow_cache) { flow_cache_ = flow_cache; } 
        FlowCachePtr getFlowCache() const { return flow_cache_;}

	void setRegexManager(const SharedPointer<RegexManager>& rm) { rm_ = rm; }

        void createTCPInfos(int number) { tcp_info_cache_->create(number); }
        void destroyTCPInfos(int number) { tcp_info_cache_->destroy(number); }
#if defined(HAVE_REJECT_FLOW)
	void addRejectFunction(std::function <void (Flow*)> reject) { reject_func_ = reject; } 
#endif
	Flow *getCurrentFlow() const { return current_flow_; } // used just for testing pourposes

        void increaseAllocatedMemory(int value) override;
        void decreaseAllocatedMemory(int value) override;

	uint64_t getCurrentUseMemory() const override;
	uint64_t getAllocatedMemory() const override;
	uint64_t getTotalAllocatedMemory() const override;

        void setDynamicAllocatedMemory(bool value) override; 
        bool isDynamicAllocatedMemory() const override; 

	uint32_t getTotalCacheMisses() const override;
	uint32_t getTotalEvents() const override { return total_events_; }

	CounterMap getCounters() const override; 
	void resetCounters() override;

	void setAnomalyManager(SharedPointer<AnomalyManager> amng) override { anomaly_ = amng; }

#if defined(STAND_ALONE_TEST) || defined(TESTING)
        Cache<TCPInfo>::CachePtr getTCPInfoCache() const { return tcp_info_cache_; }
#endif

private:
        SharedPointer<Flow> getFlow(const Packet &packet);
	void compute_tcp_state(TCPInfo *info, int32_t bytes);
#if defined(HAVE_TCP_QOS_METRICS)
	void compute_qos_rate(uint64_t &connection_time, uint64_t &application_time, uint64_t &reset_rate) const;
#endif
	FlowManagerPtr flow_table_ = nullptr;
	FlowCachePtr flow_cache_ = nullptr;
	SharedPointer<RegexManager> rm_ = nullptr;
	Cache<TCPInfo>::CachePtr tcp_info_cache_ = Cache<TCPInfo>::CachePtr(new Cache<TCPInfo>("TCP info cache"));
	Flow *current_flow_ = nullptr;
	const struct tcphdr *header_ = nullptr;
	uint32_t total_events_ = 0;
	uint32_t total_flags_syn_ = 0;
	uint32_t total_flags_synack_ = 0;
	uint32_t total_flags_ack_ = 0;
	uint32_t total_flags_rst_ = 0;
	uint32_t total_flags_fin_ = 0;
#if defined(HAVE_TCP_QOS_METRICS)
	TCPQOSStatistics qos_ {};
        uint64_t total_server_reset_rate_ = 0;
        uint64_t total_connection_setup_time_ = 0;
        uint64_t total_application_response_time_ = 0;
#endif
       	std::time_t last_timeout_ = 0;
       	std::time_t packet_time_ = 0;
#if defined(HAVE_REJECT_FLOW)
	std::function <void (Flow*)> reject_func_ = [] (Flow*) {};
#endif
	SharedPointer<AnomalyManager> anomaly_ = nullptr;
};

/// Shared-ownership handle to a TCPProtocol instance.
using TCPProtocolPtr = std::shared_ptr<TCPProtocol>;

} // namespace aiengine

#endif  // SRC_PROTOCOLS_TCP_TCPPROTOCOL_H_
